from typing import Dict, Tuple, Union

import torch
import torch.nn as nn

from CRDR.src.utils.registry import TRAINER_REGISTRY

from .multirate_hr_rgan_rate_distortion_trainer import MultirateHighRateRGANRateDistortionTrainer

@TRAINER_REGISTRY.register()
class MultirateBetaCondHrrGanRateDistortionTrainer(MultirateHighRateRGANRateDistortionTrainer):
    """Multi-rate GAN rate-distortion trainer with a beta-weighted realism term.

    The generator objective is
    ``distortion + rate + beta * (perceptual + adversarial)``, where ``beta``
    is read from the auxiliary outputs of ``run_comp_model``. The adversarial
    term compares discriminator scores of the reconstruction against scores of
    a reference image: a reconstruction at a higher rate level when one
    exists, otherwise the real image.
    """

    def optimize_parameters(self, current_iter: int, data_dict: Dict) -> Union[Dict, None]:
        """Run one generator update followed by one discriminator update.

        Args:
            current_iter: Current training iteration; forwarded to the rate
                loss and used in the skip warning.
            data_dict: Input batch passed to ``run_comp_model``. It is not
                modified by this method.

        Returns:
            A dict of values to log: ``qbpp`` (``-1`` when the model does not
            report it), ``aux`` (only when an aux optimizer is configured),
            the generator loss terms ``distortion``, ``rate``, ``perceptual``
            and ``adv``, and the discriminator statistics ``d_real``,
            ``d_fake``, ``d_total``, ``out_d_real`` and ``out_d_fake``.
            Returns ``None`` when ``check_loss_nan_inf`` flags the total
            generator loss; in that case no parameter is updated.
        """
        log_dict = {}

        ###################################################################
        #                             Train G
        ###################################################################
        # Freeze D so the generator backward pass does not accumulate
        # gradients in the discriminator parameters.
        self.discriminator.requires_grad_(False)
        self.g_optimizer.zero_grad()
        if self.aux_optimizer:
            self.aux_optimizer.zero_grad()

        # run model
        real_images, fake_images, bpp, other_outputs = self.run_comp_model(data_dict)
        rate_ind = other_outputs['rate_ind']
        beta = other_outputs['beta']

        log_dict['qbpp'] = other_outputs.get('qbpp', -1)

        # Create the reference image for the relative score: a reconstruction
        # at a higher rate level, or the real image when the requested level
        # is beyond the last available one.
        high_rate_level = rate_ind + self.relative_score_rate_delta
        if high_rate_level > self.rate_level - 1:
            relative_score_images = real_images
        else:
            # Use a shallow copy with the overrides so the caller's batch dict
            # keeps its original 'rate_ind' / 'beta' entries.
            high_rate_inputs = {**data_dict, 'rate_ind': high_rate_level, 'beta': beta}
            with torch.no_grad():
                _, relative_score_images, _, _ = self.run_comp_model(high_rate_inputs)

        # calculate losses
        g_loss_dict = {}
        dist_loss = self.distortion_loss(real_images, fake_images, **other_outputs)
        g_loss_dict['distortion'] = dist_loss

        rate_loss = self.rate_loss(bpp, **other_outputs, current_iter=current_iter)
        g_loss_dict['rate'] = rate_loss

        ## !! beta is applied to perceptual_loss and adv_loss !!
        assert self.perceptual_loss
        percep_loss = self.perceptual_loss(real_images, fake_images)
        g_loss_dict['perceptual'] = percep_loss

        # RGAN adv loss: the reference score is a constant for the generator,
        # so it is detached; only the score of the reconstruction carries
        # gradients back to the compression model.
        real_d_pred = self.discriminator(relative_score_images.detach(), **other_outputs).detach() ## calc relative score
        fake_g_pred = self.discriminator(fake_images, **other_outputs)

        l_g_real = self.gan_loss(real_d_pred - fake_g_pred, is_real=False, is_disc=False)
        l_g_fake = self.gan_loss(fake_g_pred - real_d_pred, is_real=True, is_disc=False)
        adv_loss = (l_g_real + l_g_fake) / 2
        g_loss_dict['adv'] = adv_loss

        l_total = dist_loss + rate_loss + beta * (percep_loss + adv_loss)

        # For stability
        if (loss_anomaly := self.check_loss_nan_inf(l_total)):
            self.logger.warning(f'iter{current_iter}: skipped because loss is {loss_anomaly}')
            # Undo the freeze applied above so the discriminator is not left
            # with requires_grad=False after an aborted step.
            self.discriminator.requires_grad_(True)
            return None  # skip back-propagation part

        # back prop & update parameters
        l_total.backward()
        if self.opt.optim.get('clip_max_norm'):
            nn.utils.clip_grad_norm_(self.comp_model.parameters(), self.opt.optim.clip_max_norm)
        self.g_optimizer.step()

        if self.aux_optimizer:
            log_dict['aux'] = self.optimize_aux_parameters()

        if self.g_scheduler:
            self.g_scheduler.step()

        log_dict.update(g_loss_dict)

        ###################################################################
        #                             Train D
        ###################################################################
        self.discriminator.requires_grad_(True)
        self.d_optimizer.zero_grad()

        # real
        # The fake score is only a constant baseline here, so compute it
        # without recording an autograd graph.
        with torch.no_grad():
            fake_d_pred = self.discriminator(fake_images, **other_outputs)
        real_d_pred = self.discriminator(real_images, **other_outputs)
        l_d_real = self.gan_loss(real_d_pred - fake_d_pred, is_real=True, is_disc=True) * 0.5
        l_d_real.backward()

        # fake
        # Detach the reconstruction so this backward pass stops at the
        # discriminator and does not reach the compression model.
        fake_d_pred = self.discriminator(fake_images.detach(), **other_outputs)
        l_d_fake = self.gan_loss(fake_d_pred - real_d_pred.detach(), is_real=False, is_disc=True) * 0.5
        l_d_fake.backward()

        log_dict.update({
            'd_real': l_d_real,
            'd_fake': l_d_fake,
            'd_total': l_d_real + l_d_fake,
            'out_d_real': torch.mean(real_d_pred.detach()),
            'out_d_fake': torch.mean(fake_d_pred.detach()),
        })

        self.d_optimizer.step()
        if self.d_scheduler:
            self.d_scheduler.step()

        return log_dict